{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementing an Algorithm\n",
    "\n",
    "In this tutorial we'll build a new agent that implements the Categorical Deep Q Network (C51) algorithm (https://arxiv.org/pdf/1707.06887.pdf), and a preset that runs the agent on the 'Breakout' game of the Atari environment.\n",
    "\n",
    "Implementing an algorithm typically consists of 3 main parts:\n",
    "\n",
    "1. Implementing the agent object\n",
    "2. Implementing the network head (optional)\n",
    "3. Implementing a preset to run the agent on some environment\n",
    "\n",
    "The entire agent can be defined outside of the Coach framework, but in Coach you can find multiple predefined agents under the `agents` directory, network heads under the `architecure/tensorflow_components/heads` directory, and presets under the `presets` directory, for you to reuse.\n",
    "\n",
    "For more information, we recommend going over the following page in the documentation: https://nervanasystems.github.io/coach/contributing/add_agent/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Network Head"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll start by defining a new head for the neural network used by this algorithm - ```CategoricalQHead```. \n",
    "\n",
    "A head is the final part of the network. It takes the embedding from the middleware embedder and passes it through a neural network to produce the output of the network. There can be multiple heads in a network, and each one has an assigned loss function. The heads are algorithm dependent.\n",
    "\n",
    "The rest of the network can be reused from the predefined parts, and the input embedder and middleware structure can also be modified, but we won't go into that in this tutorial.\n",
    "\n",
    "The head will typically be defined in a new file - ```architectures/tensorflow_components/heads/categorical_dqn_head.py```.\n",
    "\n",
    "First - some imports."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "module_path = os.path.abspath(os.path.join('..'))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)\n",
    "\n",
    "import tensorflow as tf\n",
    "from rl_coach.architectures.tensorflow_components.heads.head import Head\n",
    "from rl_coach.architectures.head_parameters import HeadParameters\n",
    "from rl_coach.base_parameters import AgentParameters\n",
    "from rl_coach.core_types import QActionStateValue\n",
    "from rl_coach.spaces import SpacesDefinition"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's define a class - ```CategoricalQHead``` class. Each class in Coach has a complementary Parameters class which defines its constructor parameters. So we will additionally define the ```CategoricalQHeadParameters``` class. The network structure should be defined in the `_build_module` function, which gets the previous layer output as an argument. In this function there are several variables that should be defined:\n",
    "* `self.input` - (optional) a list of any additional input to the head\n",
    "* `self.output` - the output of the head, which is also one of the outputs of the network\n",
    "* `self.target` - a placeholder for the targets that will be used to train the network\n",
    "* `self.regularizations` - (optional) any additional regularization losses that will be applied to the network\n",
    "* `self.loss` - the loss that will be used to train the network\n",
    "\n",
    "Categorical DQN uses the same network as DQN, and only changes the last layer to output #actions x #atoms elements with a softmax function. Additionally, we update the loss function to cross entropy."
   ]
  },
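  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a compact restatement of the head described above, the two changes can be written out explicitly. The symbols below ($\\theta_i$, $p_i$, $m_i$, $N$) are notation assumed here for illustration and are not Coach identifiers. For a state $x$ and an action $a$, the dense layer produces one logit $\\theta_i(x, a)$ per atom, and the softmax over the atoms turns these logits into a probability distribution:\n",
    "\n",
    "$$p_i(x, a) = \\frac{e^{\\theta_i(x, a)}}{\\sum_{j=0}^{N-1} e^{\\theta_j(x, a)}}, \\qquad i = 0, \\dots, N-1$$\n",
    "\n",
    "Given a target distribution $m_i(x, a)$ fed through `self.target`, the cross entropy loss for each state-action pair is:\n",
    "\n",
    "$$L(x, a) = -\\sum_{i=0}^{N-1} m_i(x, a) \\log p_i(x, a)$$\n",
    "\n",
    "Here $N$ is the number of atoms, so the head output has the shape (batch, #actions, #atoms)."
   ]
  },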
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CategoricalQHeadParameters(HeadParameters):\n",
    "    def __init__(self, activation_function: str ='relu', name: str='categorical_q_head_params'):\n",
    "        super().__init__(parameterized_class=CategoricalQHead, activation_function=activation_function, name=name)\n",
    "\n",
    "class CategoricalQHead(Head):\n",
    "    def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,\n",
    "                 head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str ='relu'):\n",
    "        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)\n",
    "        self.name = 'categorical_dqn_head'\n",
    "        self.num_actions = len(self.spaces.action.actions)\n",
    "        self.num_atoms = agent_parameters.algorithm.atoms\n",
    "        self.return_type = QActionStateValue\n",
    "\n",
    "    def _build_module(self, input_layer):\n",
    "        self.actions = tf.placeholder(tf.int32, [None], name=\"actions\")\n",
    "        self.input = [self.actions]\n",
    "\n",
    "        values_distribution = tf.layers.dense(input_layer, self.num_actions * self.num_atoms, name='output')\n",
    "        values_distribution = tf.reshape(values_distribution, (tf.shape(values_distribution)[0], self.num_actions,\n",
    "                                                               self.num_atoms))\n",
    "        # softmax on atoms dimension\n",
    "        self.output = tf.nn.softmax(values_distribution)\n",
    "\n",
    "        # calculate cross entropy loss\n",
    "        self.distributions = tf.placeholder(tf.float32, shape=(None, self.num_actions, self.num_atoms),\n",
    "                                            name=\"distributions\")\n",
    "        self.target = self.distributions\n",
    "        self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.target, logits=values_distribution)\n",
    "        tf.losses.add_loss(self.loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Agent\n",
    "\n",
    "The agent will implement the Categorical DQN algorithm. Each agent has a complementary ```AgentParameters``` class, which allows selecting the parameters of the agent sub modules: \n",
    "* the **algorithm**\n",
    "* the **exploration policy**\n",
    "* the **memory**\n",
    "* the **networks**\n",
    "\n",
    "Now let's go ahead and define the network parameters - it will reuse the DQN network parameters but the head parameters will be our ```CategoricalQHeadParameters```. The network parameters allows selecting any number of heads for the network by defining them in a list, but in this case we only have a single head, so we will point to its parameters class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rl_coach.agents.dqn_agent import DQNNetworkParameters\n",
    "\n",
    "\n",
    "class CategoricalDQNNetworkParameters(DQNNetworkParameters):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.heads_parameters = [CategoricalQHeadParameters()]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we'll define the algorithm parameters, which are the same as the DQN algorithm parameters, with the addition of the Categorical DQN specific `v_min`, `v_max` and number of atoms.\n",
    "We'll also define the parameters of the exploration policy, which is epsilon greedy with epsilon starting at a value of 1.0 and decaying to 0.01 throughout 1,000,000 steps."
   ]
  },
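  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the exploration schedule described above: assuming that `LinearSchedule(1, 0.01, 1000000)` interpolates linearly from its first argument to its second over the given number of steps and then holds the final value (an assumption based on its name and the description above, not checked against the implementation), the exploration rate at training step $t$ would be:\n",
    "\n",
    "$$\\epsilon(t) = \\max\\left(0.01,\\; 1 - (1 - 0.01)\\,\\frac{t}{10^6}\\right)$$\n",
    "\n",
    "For example, halfway through the decay, at $t = 500{,}000$, this gives $\\epsilon = 1 - 0.99 \\cdot 0.5 = 0.505$."
   ]
  },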
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rl_coach.agents.dqn_agent import DQNAlgorithmParameters\n",
    "from rl_coach.exploration_policies.e_greedy import EGreedyParameters\n",
    "from rl_coach.schedules import LinearSchedule\n",
    "\n",
    "\n",
    "class CategoricalDQNAlgorithmParameters(DQNAlgorithmParameters):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.v_min = -10.0\n",
    "        self.v_max = 10.0\n",
    "        self.atoms = 51\n",
    "\n",
    "\n",
    "class CategoricalDQNExplorationParameters(EGreedyParameters):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.epsilon_schedule = LinearSchedule(1, 0.01, 1000000)\n",
    "        self.evaluation_epsilon = 0.001 "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's define the agent parameters class which contains all the parameters to be used by the agent - the network, algorithm and exploration parameters that we defined above, and also the parameters of the memory module to be used, which is the default experience replay buffer in this case. \n",
    "Notice that the networks are defined as a dictionary, where the key is the name of the network and the value is the network parameters. This will allow us to later access each of the networks through `self.networks[network_name]`.\n",
    "\n",
    "The `path` property connects the parameters class to its corresponding class that is parameterized. In this case, it is the `CategoricalDQNAgent` class that we'll define in a moment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rl_coach.agents.value_optimization_agent import ValueOptimizationAgent\n",
    "from rl_coach.base_parameters import AgentParameters\n",
    "from rl_coach.core_types import StateType\n",
    "from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters\n",
    "\n",
    "\n",
    "class CategoricalDQNAgentParameters(AgentParameters):\n",
    "    def __init__(self):\n",
    "        super().__init__(algorithm=CategoricalDQNAlgorithmParameters(),\n",
    "                         exploration=CategoricalDQNExplorationParameters(),\n",
    "                         memory=ExperienceReplayParameters(),\n",
    "                         networks={\"main\": CategoricalDQNNetworkParameters()})\n",
    "\n",
    "    @property\n",
    "    def path(self):\n",
    "        return 'agents.categorical_dqn_agent:CategoricalDQNAgent'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The last step is to define the agent itself - ```CategoricalDQNAgent``` - which is a type of value optimization agent so it will inherit the ```ValueOptimizationAgent``` class. It could have also inheritted ```DQNAgent```, which would result in the same functionality. Our agent will implement the ```learn_from_batch``` function which updates the agent's networks according to an input batch of transitions.\n",
    "\n",
    "Agents typically need to implement the training function - `learn_from_batch`, and a function that defines which actions to select given a state - `choose_action`. In our case, we will reuse the `choose_action` function implemented by the generic `ValueOptimizationAgent`, and just update the internal function for fetching q values for each of the actions - `get_all_q_values_for_states`.\n",
    "\n",
    "This code may look intimidating at first glance, but basically it is just following the algorithm description in the Distributional DQN paper:\n",
    "<img src=\"files/categorical_dqn.png\" width=400>"
   ]
  },
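  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the link between the algorithm and the code below easier to follow, the categorical projection can be written out as follows. This is a sketch in the notation of the paper; $\\Delta z$, $d$ and $a^*$ are shorthand introduced here, with $d = 1$ for a terminal transition and $a^*$ the greedy next-state action. The atoms are $z_j = V_{min} + j\\,\\Delta z$ with $\\Delta z = (V_{max} - V_{min}) / (N - 1)$, and the Q value of an action is the expectation of its distribution:\n",
    "\n",
    "$$Q(x, a) = \\sum_{j=0}^{N-1} z_j\\, p_j(x, a)$$\n",
    "\n",
    "Each atom is shifted by the Bellman update, clipped to the support and mapped to a fractional atom index:\n",
    "\n",
    "$$\\hat{T} z_j = \\left[ r + (1 - d)\\,\\gamma z_j \\right]_{V_{min}}^{V_{max}}, \\qquad b_j = \\frac{\\hat{T} z_j - V_{min}}{\\Delta z}, \\qquad l = \\lfloor b_j \\rfloor, \\quad u = \\lceil b_j \\rceil$$\n",
    "\n",
    "The probability of atom $j$ is then split between the two neighboring atoms in proportion to their proximity:\n",
    "\n",
    "$$m_l \\leftarrow m_l + p_j(x_{t+1}, a^*)\\,(u - b_j), \\qquad m_u \\leftarrow m_u + p_j(x_{t+1}, a^*)\\,(b_j - l)$$\n",
    "\n",
    "These quantities correspond to `tzj`, `bj`, `l`, `u` and `m` in `learn_from_batch`. As a worked example with the default parameters ($V_{min} = -10$, $V_{max} = 10$, $N = 51$, so $\\Delta z = 0.4$): for a terminal transition with $r = 1$, every atom maps to $\\hat{T} z_j = 1$, so $b_j = 11 / 0.4 = 27.5$, $l = 27$ and $u = 28$. The target distribution therefore places probability 0.5 on $z_{27} = 0.8$ and 0.5 on $z_{28} = 1.2$, whose expectation is exactly 1."
   ]
  },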
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Union\n",
    "\n",
    "\n",
    "# Categorical Deep Q Network - https://arxiv.org/pdf/1707.06887.pdf\n",
    "class CategoricalDQNAgent(ValueOptimizationAgent):\n",
    "    def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):\n",
    "        super().__init__(agent_parameters, parent)\n",
    "        self.z_values = np.linspace(self.ap.algorithm.v_min, self.ap.algorithm.v_max, self.ap.algorithm.atoms)\n",
    "\n",
    "    def distribution_prediction_to_q_values(self, prediction):\n",
    "        return np.dot(prediction, self.z_values)\n",
    "\n",
    "    # prediction's format is (batch,actions,atoms)\n",
    "    def get_all_q_values_for_states(self, states: StateType):\n",
    "        prediction = self.get_prediction(states)\n",
    "        return self.distribution_prediction_to_q_values(prediction)\n",
    "\n",
    "    def learn_from_batch(self, batch):\n",
    "        network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()\n",
    "\n",
    "        # for the action we actually took, the error is calculated by the atoms distribution\n",
    "        # for all other actions, the error is 0\n",
    "        distributed_q_st_plus_1, TD_targets = self.networks['main'].parallel_prediction([\n",
    "            (self.networks['main'].target_network, batch.next_states(network_keys)),\n",
    "            (self.networks['main'].online_network, batch.states(network_keys))\n",
    "        ])\n",
    "\n",
    "        # only update the action that we have actually done in this transition\n",
    "        target_actions = np.argmax(self.distribution_prediction_to_q_values(distributed_q_st_plus_1), axis=1)\n",
    "        m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))\n",
    "\n",
    "        batches = np.arange(self.ap.network_wrappers['main'].batch_size)\n",
    "        for j in range(self.z_values.size):\n",
    "            tzj = np.fmax(np.fmin(batch.rewards() +\n",
    "                                  (1.0 - batch.game_overs()) * self.ap.algorithm.discount * self.z_values[j],\n",
    "                                  self.z_values[self.z_values.size - 1]),\n",
    "                          self.z_values[0])\n",
    "            bj = (tzj - self.z_values[0])/(self.z_values[1] - self.z_values[0])\n",
    "            u = (np.ceil(bj)).astype(int)\n",
    "            l = (np.floor(bj)).astype(int)\n",
    "            m[batches, l] = m[batches, l] + (distributed_q_st_plus_1[batches, target_actions, j] * (u - bj))\n",
    "            m[batches, u] = m[batches, u] + (distributed_q_st_plus_1[batches, target_actions, j] * (bj - l))\n",
    "        # total_loss = cross entropy between actual result above and predicted result for the given action\n",
    "        TD_targets[batches, batch.actions()] = m\n",
    "\n",
    "        result = self.networks['main'].train_and_sync_networks(batch.states(network_keys), TD_targets)\n",
    "        total_loss, losses, unclipped_grads = result[:3]\n",
    "\n",
    "        return total_loss, losses, unclipped_grads"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Some important things to notice here:\n",
    "* `self.networks['main']` is a NetworkWrapper object. It holds all the copies of the 'main' network: \n",
    "    - a **global network** which is shared between all the workers in distributed training\n",
    "    - an **online network** which is a local copy of the network intended to keep the weights static between training steps\n",
    "    - a **target network** which is a local slow updating copy of the network, and is intended to keep the targets of the training process more stable\n",
    "  In this case, we have the online network and the target network. The global network will only be created if we run the algorithm with multiple workers. The A3C agent would be one kind of example. \n",
    "* There are two network prediction functions available - `predict` and `parallel_prediction`. `predict` is quite straightforward - get some inputs, forward them through the network and return the output. `parallel_prediction` is an optimized variant of `predict`, which allows running a prediction on the online and target network in parallel, instead of running them sequentially.\n",
    "* The network `train_and_sync_networks` function makes a single training step - running a forward pass of the online network, calculating the losses, running a backward pass to calculate the gradients and applying the gradients to the network weights. If multiple workers are used, instead of applying the gradients to the online network weights, they are applied to the global (shared) network weights, and then the weights are copied back to the online network."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Preset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The final part is the preset, which will run our agent on some existing environment with any custom parameters.\n",
    "\n",
    "The new preset will be typically be defined in a new file - ```presets/atari_categorical_dqn.py```.\n",
    "\n",
    "First - let's select the agent parameters we defined above. \n",
    "It is possible to modify internal parameters such as the learning rate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rl_coach.agents.categorical_dqn_agent import CategoricalDQNAgentParameters\n",
    "\n",
    "\n",
    "agent_params = CategoricalDQNAgentParameters()\n",
    "agent_params.network_wrappers['main'].learning_rate = 0.00025"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's define the environment parameters. We will use the default Atari parameters (frame skip of 4, taking the max over subsequent frames, etc.), and we will select the 'Breakout' game level."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4\n",
    "\n",
    "\n",
    "env_params = Atari(level='BreakoutDeterministic-v4')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Connecting all the dots together - we'll define a graph manager with the Categorial DQN agent parameters, the Atari environment parameters, and the scheduling and visualization parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager\n",
    "from rl_coach.base_parameters import VisualizationParameters\n",
    "from rl_coach.environments.gym_environment import atari_schedule\n",
    "\n",
    "graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,\n",
    "                                    schedule_params=atari_schedule, vis_params=VisualizationParameters())\n",
    "graph_manager.visualization_parameters.render = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running the Preset\n",
    "(this is normally done from command line by running ```coach -p Atari_C51 -lvl breakout```)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# let the adventure begin\n",
    "graph_manager.improve()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
